// Copyright (c) 2021 RonxBulld
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef YAMQ_YAMQ_H
#define YAMQ_YAMQ_H

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace yamq {
    /// String-to-string key/value record used as a message payload.
    ///
    /// Behaves like std::map<std::string, std::string>, except that the const
    /// operator[] looks a key up without inserting it and yields an empty
    /// string when the key is absent.
    ///
    /// The class declares no data members and no special member functions
    /// (Rule of Zero). The compiler-generated copy/move operations are
    /// therefore used, and moving a KVdb is a real move rather than a silent
    /// deep copy.
    class KVdb : public std::map<std::string, std::string> {
    private:
        using base_ty = std::map<std::string, std::string>;

        // Shared immutable value returned by the const operator[] on a miss.
        // This is a function-local static rather than a const data member,
        // because a const member would delete the implicit assignment
        // operators and, transitively, the implicit moves.
        static const std::string &EmptyHolder() {
            static const std::string empty_holder;
            return empty_holder;
        }
    public:
        // Inherit std::map's constructors (initializer list, iterator range, ...).
        using base_ty::base_ty;

        /// Read-only lookup.
        /// @return the value stored under @p key, or a reference to a shared
        ///         empty string when the key is absent. Never inserts.
        const std::string &operator[](const std::string &key) const {
            const auto found = this->find(key);
            return found != this->end() ? found->second : EmptyHolder();
        }

        /// Mutable access with std::map semantics: default-inserts an empty
        /// value when @p key is absent.
        std::string &operator[](const std::string &key) {
            return base_ty::operator[](key);
        }
    };

    template <bool B>
    using bool_constant = std::integral_constant<bool, B>;
    template <class B>
    struct negation : bool_constant<!bool(B::value)> { };
    template <class T, class U>
    struct is_not_same : negation<std::is_same<T, U>> { };

    /// Execution policy that runs every task synchronously on the calling
    /// thread. It exposes the same enqueue / NotifyAllExit / WaitAllFree
    /// interface as ParallelSlot, so PubSub can be parameterised on either.
    class SerialSlot {
    public:
        inline explicit SerialSlot() {}
        ~SerialSlot() {}

        /// Invokes f(args...) immediately and returns its result.
        /// The call expression is returned as-is (no std::move). Prvalue
        /// results are therefore elided, and reference-returning callables
        /// hand back the very same reference.
        template <
                typename F,
                typename ... Args,
                typename R = typename std::result_of<F(Args...)>::type,
                typename = typename std::enable_if<!std::is_void<R>::value>::type
                >
        inline R enqueue(F &&f, Args && ... args) {
            return std::forward<F>(f)(std::forward<Args>(args)...);
        }

        /// Overload for callables returning void: invokes f(args...) immediately.
        template <
                typename F,
                typename ... Args,
                typename R = typename std::result_of<F(Args...)>::type,
                typename = typename std::enable_if<std::is_void<R>::value>::type
                >
        inline void enqueue(F &&f, Args && ... args) {
            std::forward<F>(f)(std::forward<Args>(args)...);
        }

        /// No-op: there are no worker threads to stop.
        inline void NotifyAllExit() {
        }

        /// No-op: every task has already finished when enqueue() returns.
        inline void WaitAllFree() {
        }
    };

    /// Execution policy that runs tasks on a pool of worker threads.
    ///
    /// - enqueue() hands work to the pool and returns a std::future.
    /// - WaitAllFree() blocks until every accepted task, whether queued or
    ///   running, has completed.
    /// - NotifyAllExit() stops accepting work, lets the workers drain the
    ///   queue and joins them.
    ///
    /// With auto_extend, the pool spawns more workers whenever fewer than
    /// three are idle at the moment a task is picked up.
    ///
    /// Locking:
    /// - workers_, tasks_queue_ and stop_ are guarded by queue_mutex_.
    /// - Updates of worker_used_ and unfinished_tasks_ are guarded by
    ///   worker_used_mtx_.
    /// - When both mutexes are needed, they are always taken in that order.
    class ParallelSlot {
    private:
        /// Marks one more worker as busy and wakes WaitAllFree() waiters.
        void UsedWorkerIncrease() {
            {
                std::lock_guard<std::mutex> lock(worker_used_mtx_);
                worker_used_++;
            }
            worker_used_change_cv_.notify_all();
        }
        /// Marks a worker idle again, retires the task it just finished and
        /// wakes WaitAllFree() waiters.
        void UsedWorkerDecrease() {
            {
                std::lock_guard<std::mutex> lock(worker_used_mtx_);
                worker_used_--;
                unfinished_tasks_--;
            }
            worker_used_change_cv_.notify_all();
        }
        /// Starts one more worker thread. The caller must hold queue_mutex_.
        void IncreaseWorker() {
            workers_.emplace_back([this] { this->WorkerLoop(); });
            max_worker_++;
        }
        /// Main loop of every worker: take a task, run it, and repeat until
        /// the pool is stopped and the queue is empty.
        void WorkerLoop() {
            // Auto-extension keeps at least this many idle workers and grows
            // the pool by at least this many threads at a time.
            constexpr std::size_t kMinIdleWorkers = 3;
            while (true) {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(queue_mutex_);
                    cv_.wait(lock, [this] { return this->stop_ || !this->tasks_queue_.empty(); });
                    if (stop_ && tasks_queue_.empty()) {
                        return;
                    }
                    UsedWorkerIncrease();
                    task = std::move(tasks_queue_.front());
                    tasks_queue_.pop();
                    // Already includes this worker.
                    const unsigned busy = worker_used_.load();
                    // Written as "busy + min > size": the subtraction
                    // "size - min" would wrap for pools smaller than
                    // kMinIdleWorkers and disable growth entirely.
                    // No growth once stopped, so NotifyAllExit() sees every thread.
                    if (auto_extend_ && !stop_ && busy + kMinIdleWorkers > workers_.size()) {
                        const std::size_t grow = std::max<std::size_t>(workers_.size() / 2, kMinIdleWorkers);
                        for (std::size_t idx = 0; idx < grow; ++idx) {
                            IncreaseWorker();
                        }
                    }
                    // Every writer holds queue_mutex_, so this load/store pair
                    // cannot lose updates.
                    if (busy > worker_max_used_.load()) {
                        worker_max_used_.store(busy);
                    }
                }
                task();
                UsedWorkerDecrease();
            }
        }
    public:
        /// @param max_threads number of workers started up front. 0 is treated
        ///        as 1, because a pool without workers could never run a task
        ///        and WaitAllFree() would never return.
        /// @param auto_extend when true, grow the pool on demand (see class doc).
        inline explicit ParallelSlot(size_t max_threads, bool auto_extend = false)
                : stop_(false), auto_extend_(auto_extend) {
            worker_used_.store(0);
            worker_max_used_.store(0);
            const size_t initial_workers = max_threads == 0 ? 1 : max_threads;
            // Keeps the invariant that workers_ is only mutated under queue_mutex_.
            // Freshly started workers simply block on it until the loop finishes.
            std::lock_guard<std::mutex> lock(queue_mutex_);
            for (size_t i = 0; i < initial_workers; ++i) {
                IncreaseWorker();
            }
        }

        /// Schedules f(args...) on the pool.
        /// @return a future for the call's result. After NotifyAllExit() the
        ///         task is dropped, and the future reports
        ///         std::future_errc::broken_promise.
        template <typename F, typename ... Args, typename R = typename std::result_of<F(Args...)>::type>
        inline std::future<R> enqueue(F &&f, Args && ... args) {
            auto task = std::make_shared<std::packaged_task<R()>>
                    (std::bind(std::forward<F>(f), std::forward<Args>(args)...));
            std::future<R> res = task->get_future();
            {
                std::unique_lock<std::mutex> lock(queue_mutex_);
                if (!stop_) {
                    tasks_queue_.emplace([task]() { (*task)(); });
                    // Count the task only once it is really queued. Lock order
                    // is queue_mutex_ -> worker_used_mtx_, as in the workers.
                    std::lock_guard<std::mutex> count_lock(worker_used_mtx_);
                    unfinished_tasks_++;
                }
            }
            cv_.notify_one();
            return res;
        }

        /// Blocks until every accepted task, queued or running, has finished.
        inline void WaitAllFree() {
            std::unique_lock<std::mutex> lock(worker_used_mtx_);
            worker_used_change_cv_.wait(lock, [this]() {
                return this->unfinished_tasks_ == 0;
            });
        }

        /// Stops accepting tasks, lets the workers drain the queue, then joins
        /// them. Calling it again is harmless.
        inline void NotifyAllExit() {
            std::vector<std::thread> to_join;
            {
                std::unique_lock<std::mutex> lock(queue_mutex_);
                stop_ = true;
                // Taken in the same critical section that sets stop_. No worker
                // can therefore add threads after this point, because growth
                // requires !stop_ under queue_mutex_.
                to_join.swap(workers_);
            }
            cv_.notify_all();
            for (std::thread &worker : to_join) {
                worker.join();
            }
        }

        ~ParallelSlot() {
            NotifyAllExit();
            std::cout << Report() << std::endl;
        }

        /// Peak number of simultaneously busy workers vs. workers ever started.
        inline std::string Report() const {
            std::string report;
            report.append("Worker used rate: ")
                  .append(std::to_string(worker_max_used_.load()))
                  .append(" / ")
                  .append(std::to_string(max_worker_.load()));
            return report;
        }

    private:
        std::vector<std::thread> workers_;                // guarded by queue_mutex_
        std::queue<std::function<void()>> tasks_queue_;   // guarded by queue_mutex_
        std::mutex queue_mutex_;
        std::condition_variable cv_;
        bool stop_;                                       // guarded by queue_mutex_
        bool auto_extend_;
        std::atomic_uint worker_max_used_{}, max_worker_{};

        std::atomic_uint        worker_used_            {};  // workers currently running a task
        std::size_t             unfinished_tasks_       {0}; // queued + running; guarded by worker_used_mtx_
        std::mutex              worker_used_mtx_        {};
        std::condition_variable worker_used_change_cv_  {};
    };

    class ObserverBase;
    template <bool IsObserver> struct subscriber ;

    /// Policy-independent base of PubSub. It lets subscribers be registered
    /// without knowing the execution policy, and it owns the observers it
    /// creates internally from plain callables.
    class PubSubBase {
    private:
        // Observers allocated on behalf of callers (wrappers around callables).
        // They are released when this object is destroyed.
        std::set<std::shared_ptr<ObserverBase>> created_inside_;
        std::mutex created_inside_mtx_;
    protected:
        /// Adds @p observer to the subscriber set of @p uri.
        virtual void AddObserver(const std::string &uri, ObserverBase *observer) = 0;
    public:
        /// Removes @p terminal from every topic of this object.
        virtual void TerminalOffline(ObserverBase *terminal) = 0;

        /// Subscribes @p pred to @p uri.
        ///
        /// @p pred is one of two kinds of subscriber:
        /// - An ObserverBase-derived object. It is registered by address and is
        ///   not owned here, so it presumably must outlive its subscription.
        /// - Any other callable. It is wrapped in a new observer that this
        ///   object owns.
        template <typename T>
        void Subscribe(const std::string &uri, T && pred) {
            using PredDecay = typename std::decay<T>::type;
            constexpr bool kIsObserver =
                    std::is_convertible<PredDecay *, ObserverBase *>::value;
            // Observers are always passed as lvalues (subscriber<true> binds an
            // ObserverBase&). Callables are perfectly forwarded: a temporary
            // lambda is then moved into the wrapper, instead of the wrapper
            // keeping a reference that dangles once the full-expression ends.
            using PassTy = typename std::conditional<
                    kIsObserver,
                    typename std::remove_reference<T>::type &,
                    T &&>::type;
            ObserverBase *observer;
            bool auto_release;
            std::tie(observer, auto_release) =
                    subscriber<kIsObserver>::getPtr(*this, static_cast<PassTy>(pred));
            if (observer && auto_release) {
                std::lock_guard<std::mutex> guard(created_inside_mtx_);
                created_inside_.insert(std::shared_ptr<ObserverBase>(observer));
            }
            AddObserver(uri, observer);
        }
        virtual ~PubSubBase() {
            std::lock_guard<std::mutex> guard(created_inside_mtx_);
            created_inside_.clear();
        }
    };

    /// Base class for message consumers.
    ///
    /// The constructor starts a worker thread for each observer. That thread
    /// pops messages queued by PushMsg() and passes them to Invoke().
    /// DisableAndWait() stops the thread; messages still queued at that point
    /// are discarded.
    ///
    /// NOTE(review): ~ObserverBase runs after the derived part has been
    /// destroyed, so a message being handled at that moment would reach a
    /// half-destroyed object. Derived classes should call DisableAndWait() in
    /// their own destructor.
    class ObserverBase {
    private:
        using lg = std::lock_guard<std::mutex>;
        std::set<PubSubBase *> pubsub_set_;   // PubSubs this observer is registered with
        std::mutex pubsubset_setup_mtx_;      // guards pubsub_set_

        std::atomic_bool runable_{true};      // cleared once, to stop the worker
        std::thread working_thread_;
        std::mutex busy_current_mtx_;         // held while Invoke() runs

        std::queue<KVdb> message_queue_;      // guarded by msgq_mtx_
        std::mutex msgq_mtx_;
        std::condition_variable msgq_cv_;
    private:
        /// Asks @p pub_sub to drop this observer from all of its topics.
        void NotifyStop(PubSubBase *pub_sub) {
            pub_sub->TerminalOffline(this);
        }
        /// NotifyStop() for every registered PubSub. Caller holds pubsubset_setup_mtx_.
        void NotifyAllStop() {
            for (PubSubBase *ps : pubsub_set_) {
                NotifyStop(ps);
            }
        }
        /// Starts the thread that waits for messages and processes them.
        bool Start() {
            working_thread_ = std::thread([this]() {
                KVdb kvdb;
                while (this->runable_.load()) {
                    if (this->PopMsg_WithBlocking(kvdb) && this->runable_.load()) {
                        lg run_lock(busy_current_mtx_);
                        this->Invoke(kvdb);
                    }
                }
            });
            return true;
        }
    public:
        ObserverBase() {
            Start();
        }
        /// Queues a copy of @p kvdb for the worker thread; ignored once disabled.
        void PushMsg(const KVdb &kvdb) {
            std::unique_lock<std::mutex> msgq_lk(msgq_mtx_);
            if (this->runable_.load()) {
                message_queue_.emplace(kvdb);
                msgq_cv_.notify_one();
            }
        }
        /// Blocks until a message is available or the observer is disabled.
        /// @return true after moving the oldest message into @p kvdb; false
        ///         once the observer has been disabled.
        bool PopMsg_WithBlocking(KVdb &kvdb) {
            std::unique_lock<std::mutex> msgq_lk(msgq_mtx_);
            msgq_cv_.wait(msgq_lk, [this]()->bool { return !this->runable_.load() || !this->message_queue_.empty(); });
            if (this->runable_.load() ==  false) {
                return false;
            }
            if (!message_queue_.empty()) {
                kvdb = std::move(message_queue_.front());
                message_queue_.pop();
                return true;
            } else {
                return false;
            }
        }
        /// Handles one message; runs on this observer's worker thread.
        virtual void Invoke(const KVdb &kvdb) = 0 ;

        void RegisterPubSub(PubSubBase &pub_sub) {
            lg lock(pubsubset_setup_mtx_);
            pubsub_set_.insert(&pub_sub);
        }
        void UnregisterPubSub(PubSubBase *pub_sub) {
            lg lock(pubsubset_setup_mtx_);
            pubsub_set_.erase(pub_sub);
        }

        /// Detaches from a single PubSub.
        void Disconnect(PubSubBase *pub_sub) {
            lg lock(pubsubset_setup_mtx_);
            NotifyStop(pub_sub);
            pubsub_set_.erase(pub_sub);
        }
        /// Detaches from every PubSub this observer is registered with.
        /// NOTE(review): TerminalOffline() is called while pubsubset_setup_mtx_
        /// is held. An implementation that, under its own lock, calls back into
        /// RegisterPubSub/UnregisterPubSub can deadlock with this. Confirm the
        /// lock ordering against callers.
        void DisconnectAll() {
            lg lock(pubsubset_setup_mtx_);
            NotifyAllStop();
            pubsub_set_.clear();
        }

        /// Stops the worker thread and waits for it to finish. Idempotent.
        void DisableAndWait() {
            {
                // Clear the flag while holding msgq_mtx_. The worker evaluates
                // its wait predicate under this mutex, so it either sees false
                // or is already blocked and receives the notify below. Storing
                // without the lock can let the notify fire between the
                // predicate check and the block. The wake-up is then lost, and
                // join() waits forever.
                lg msgq_lk(msgq_mtx_);
                runable_.store(false);
            }
            msgq_cv_.notify_all();
            if (working_thread_.joinable()) {
                working_thread_.join();
            }
        }

        virtual ~ObserverBase() {
            DisconnectAll();
            DisableAndWait();
        }
    };

    template <typename T>
    class GenericObserver : public ObserverBase {
    private:
        T pred_;
    public:
        explicit GenericObserver(T pred) : pred_(pred) {}
        void Invoke(const KVdb &kvdb) override {
            pred_(kvdb);
        }
        ~GenericObserver() override = default;
    };

    template <> struct subscriber<true> {
        static
        std::pair<ObserverBase *, bool>
        getPtr(PubSubBase &pubsub, ObserverBase &observer) {
            return {&observer, false};
        }
    };
    /// Subscription helper for plain callables. It wraps the callable in a
    /// heap-allocated GenericObserver and flags it for automatic release, so
    /// the caller takes ownership of the returned pointer.
    template <> struct subscriber<false> {
        template <typename Callable>
        static
        std::pair<ObserverBase *, bool>
        getPtr(PubSubBase &pubsub, Callable && pred) {
            // Decay so the wrapper stores its own copy (or moved-in value) of
            // the callable. When Callable is deduced as an lvalue reference, the
            // wrapper would otherwise hold a reference that dangles once the
            // caller's object is gone, e.g. a temporary lambda.
            using Stored = typename std::decay<Callable>::type;
            ObserverBase *generic_ob =
                    new GenericObserver<Stored>(std::forward<Callable>(pred));
            return {generic_ob, true};
        }
    };

    template <typename SlotTy>
    class PubSub : public PubSubBase
                 , protected SlotTy
    {
    private:
        std::map<std::string, std::set<ObserverBase *>> pool_;
        std::mutex pool_opt_mtx_;
        using lg = std::lock_guard<std::mutex>;
    private:
        void NotifyAllStop() {
            for (auto &slot : pool_) {
                for (ObserverBase *ob : slot.second) {
                    ob->UnregisterPubSub(this);
                }
                slot.second.clear();
            }
        }
    protected:
        void AddObserver(const std::string &uri, ObserverBase *observer) override {
            if (observer) {
                lg lock(pool_opt_mtx_);
                pool_[uri].insert(observer);
                observer->RegisterPubSub(*this);
            }
        }
    public:
        template <typename ... Args>
        explicit PubSub(Args && ... args) : SlotTy(std::forward<Args>(args)...) {
        }
        ~PubSub() override {
            lg lock(pool_opt_mtx_);

            SlotTy::NotifyAllExit();

            NotifyAllStop();
            pool_.clear();
        }
        void TerminalOffline(ObserverBase *terminal) override {
            lg lock(pool_opt_mtx_);
            for (auto &item : pool_) {
                item.second.erase(terminal);
            }
        }
        void Publish(const std::string &uri, const KVdb &msg) {
            lg lock(pool_opt_mtx_);
            auto slot_it = pool_.find(uri);
            if (slot_it != pool_.end()) {
                // Dispatch the message to the slot members.
                for (ObserverBase *ob : slot_it->second) {
                    SlotTy::enqueue([ob, msg, this](){
                        ob->PushMsg(msg);
                    });
                }
                SlotTy::WaitAllFree();
            }
        }
        bool Ping(const std::string &uri) {
            lg lock(pool_opt_mtx_);
            auto slot_it = pool_.find(uri);
            return slot_it != pool_.end();
        }
    };

}

#endif//YAMQ_YAMQ_H
